Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix]: properly deserialize tool_calls iterator before processing by mistral-common when MistralTokenizer is used #9951

Conversation

gcalmettes
Copy link
Contributor

@gcalmettes gcalmettes commented Nov 2, 2024

There is currently a bug in the Pydantic library where attributes declared as iterables are eagerly evaluated and therefore replaced in the instances by pydantic-core ValidatorIterator instance, when waiting to be consumed.

When the MistralTokenizer is used, no chat template is applied, and the request messages are directly sent to be processed by mistral-common. As a result, the tool_calls field, which is defined as Iterable in the Assitant message request object definition (both in the vllm CustomChatCompletionMessageParam or the official OpenAI ChatCompletionAssistantMessageParam) is not consumed and is sent as a pydantic ValidatorIterator to mistral-common :

{'role': 'assistant', 'content': None, 'tool_calls': ValidatorIterator(index=0, schema=Some(DefinitionRef(DefinitionRefValidator { definition: "typed-dict" })))}

The tool_calls field is then rightfully processed as an empty list by mistral-common, meaning that if any tool_calls is sent with an assistant message, they are ignored after processing by mistral-common:

AssistantMessage(role='assistant', content=None, tool_calls=[], prefix=False)

This causes the mistral-common validation check (here) to fail as both side are evaluated to false.

ERROR 09-29 17:36:44 serving_chat.py:153]   File "/conda/envs/vllm_env/lib/python3.12/site-packages/mistral_common/protocol/instruct/validator.py", line 147, in _validate_assistant_message
ERROR 09-29 17:36:44 serving_chat.py:153]     raise InvalidAssistantMessageException(
ERROR 09-29 17:36:44 serving_chat.py:153] mistral_common.exceptions.InvalidAssistantMessageException: Assistant message must have either content or tool_calls, but not both.

The bug is not seen when chat templates are used, since the tools_calls iterator is consumed in the template when looping over each tool_call.

The bug is known on Pydantic side, and it indeed particularly affects the tool_calls field for LLM-based workloads using the OpenAI client (see this issue for exemple).

This PR makes the tool_calls received in the request for the assistant messages being consumed (and therefore valided) when the MistralTokenizer is used.

FIX #9059

Copy link

github-actions bot commented Nov 2, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the frontend label Nov 2, 2024
@gcalmettes gcalmettes force-pushed the feat/support-tool-parsing-for-mistral-tokenizer branch 3 times, most recently from 34183bb to 9306944 Compare November 2, 2024 11:52
@gcalmettes gcalmettes changed the title [BugFix]: deserialize tool_calls before processing by mistral-common when MistralTokenizer is used [BugFix]: consume and validatetool_calls iterator before processing by mistral-common when MistralTokenizer is used Nov 2, 2024
@gcalmettes gcalmettes changed the title [BugFix]: consume and validatetool_calls iterator before processing by mistral-common when MistralTokenizer is used [BugFix]: consume and validate tool_calls iterator before processing by mistral-common when MistralTokenizer is used Nov 2, 2024
@gcalmettes gcalmettes changed the title [BugFix]: consume and validate tool_calls iterator before processing by mistral-common when MistralTokenizer is used [BugFix]: properly deserialize tool_calls iterator before processing by mistral-common when MistralTokenizer is used Nov 4, 2024
@gcalmettes gcalmettes force-pushed the feat/support-tool-parsing-for-mistral-tokenizer branch from 1e8c3b1 to 8aa0146 Compare November 4, 2024 08:13
@patrickvonplaten
Copy link
Contributor

Thanks a bunch for the PR! @gcalmettes do you have an easy reproducible on this error by any chance? (which is fixed by this PR?)

@gcalmettes
Copy link
Contributor Author

gcalmettes commented Nov 8, 2024

Hi @patrickvonplaten,

Sure, please find below a code that breaks (short version for isolated example and long turn by turn version).
The short version uses dict as payload for the messages, the long version uses the native OpenAI ChatCompletionMessage object. Both break for the same root cause, the tool_calls of the assistant message is not deserialized properly before being sent to the Mistral Tokenizer, and therefore both content and tool_calls are empty.

Short version:

"""
# using the latest vllm code:
# pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl

vllm serve mistralai/Pixtral-12B-2409 \
           --tokenizer-mode mistral \
           --limit-mm-per-prompt image=5 \
           --enable-auto-tool-choice \
           --tool-call-parser=mistral
"""

from openai import OpenAI

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

messages = [
    {'role': 'user', 'content': 'Can you tell me what the temperate will be in Dallas, in fahrenheit?'},
    {'content': None, 'role': 'assistant', 'tool_calls': [{'id': 'RBS92VTjJ', 'function': {'arguments': '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}', 'name': 'get_current_weather'}, 'type': 'function'}]},
    {'role': 'tool', 'content': "The weather in Dallas, TX is 85 degrees fahrenheit. It is partly cloudly, with highs in the 90's.", 'tool_call_id': 'n3OMUpydP'}
]
chat_completion = client.chat.completions.create(messages=messages,
                                                 model=model)

Turn by turn (long) version:

It's basically an adaptation of your offline_chat_with_tool_exemple example, but querying the vllm openai frontend.

The problem occurs when the assistant message gets a parsed tool_call.
I have also put the error that we get without this PR at the end.

"""
# using the latest vllm code:
# pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl

vllm serve mistralai/Pixtral-12B-2409 \
           --tokenizer-mode mistral \
           --limit-mm-per-prompt image=5 \
           --enable-auto-tool-choice \
           --tool-call-parser=mistral
"""

import json
import string
import random

from openai import OpenAI
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

# will be used to ensure the length=9 constraint on tool call ids
def generate_random_id(length=9):
    characters = string.ascii_letters + string.digits
    random_id = ''.join(random.choice(characters) for _ in range(length))
    return random_id

# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: 'str'):
    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
            "partly cloudly, with highs in the 90's.")

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8004/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

tool_functions = {"get_current_weather": get_current_weather}

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type":
                    "string",
                    "description":
                    "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type":
                    "string",
                    "description":
                    "the two-letter abbreviation for the state that the city is"
                    " in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

messages = [{
    "role":
    "user",
    "content":
    "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]

chat_completion = client.chat.completions.create(messages=messages,
                                                 model=model,
                                                 tools=tools)

# parse tool_call response and execute the function call
assistant_message = chat_completion.choices[0].message
output = assistant_message.content
tool_calls = json.loads(output)
tool_call_id = generate_random_id()
tool_answers = [
    tool_functions[call['name']](**call['arguments']) for call in tool_calls
]

# Add assistant message and tool message to the converstation

# Simulate the tool call was automatically parsed
# (the Mistral Tokenizer prevent the [TOOL_CALLS] token to be generated so the tool_call cannot be automatically detected by the --tool-parser mistral)
assistant_message.content = None
assistant_message.tool_calls = [
        ChatCompletionMessageToolCall.parse_obj({
            'id': tool_call_id,
            "function": {
                'arguments': json.dumps(tool_calls[0]["arguments"]),
                'name': tool_calls[0]["name"]
            },
            'type': 'function',
        })
    ]

# append the assistant message
messages.append(assistant_message)

# append the answer as a tool message
messages.append({
    "role": "tool",
    "content": "\n\n".join(tool_answers),
    "tool_call_id": generate_random_id(),
})

# Send everything to the LLM to give a final answer
# THIS IS IN THIS REQUEST THAT THE CODE WILL BREAK
chat_completion = client.chat.completions.create(messages=messages,
                                                 model=model,
                                                 tools=tools)

print(chat_completion.choices[0].message.content)

without this PR, the code breaks on the server side with:

ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/applications.py", line 113, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 187, in __call__
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/cors.py", line 85, in __call__
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 715, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 735, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 73, in app
    response = await f(request)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 301, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 212, in run_endpoint_function
    return await dependant.call(**values)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 338, in create_chat_completion
    generator = await handler.create_chat_completion(request, raw_request)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/serving_chat.py", line 131, in create_chat_completion
    ) = await self._preprocess_chat(
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/serving_engine.py", line 458, in _preprocess_chat
    request_prompt = apply_mistral_chat_template(
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/chat_utils.py", line 767, in apply_mistral_chat_template
    return tokenizer.apply_chat_template(
  File "/usr/local/lib/python3.10/dist-packages/vllm/transformers_utils/tokenizers/mistral.py", line 215, in apply_chat_template
    encoded = self.mistral.encode_chat_completion(request)
  File "/usr/local/lib/python3.10/dist-packages/mistral_common/tokens/tokenizers/mistral.py", line 174, in encode_chat_completion
    validated_request = self._chat_completion_request_validator.validate_request(request)
  File "/usr/local/lib/python3.10/dist-packages/mistral_common/protocol/instruct/validator.py", line 63, in validate_request
    self.validate_messages(request.messages)
  File "/usr/local/lib/python3.10/dist-packages/mistral_common/protocol/instruct/validator.py", line 51, in validate_messages
    self._validate_message_list_content(messages)
  File "/usr/local/lib/python3.10/dist-packages/mistral_common/protocol/instruct/validator.py", line 273, in _validate_message_list_content
    self._validate_assistant_message(message, is_last_message=idx == len(messages) - 1)
  File "/usr/local/lib/python3.10/dist-packages/mistral_common/protocol/instruct/validator.py", line 147, in _validate_assistant_message
    raise InvalidAssistantMessageException(
mistral_common.exceptions.InvalidAssistantMessageException: Assistant message must have either content or tool_calls, but not both.

@patrickvonplaten
Copy link
Contributor

Nice catch! Just verified this and this fix is indeed needed to correctly parse tool calls. @ywang96 @DarkLight1337 do you think we could get something like this merged?

Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) November 14, 2024 03:06
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 14, 2024
@DarkLight1337 DarkLight1337 merged commit 52b48c1 into vllm-project:main Nov 14, 2024
70 checks passed
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
…g by mistral-common when MistralTokenizer is used (vllm-project#9951)

Signed-off-by: Guillaume Calmettes <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
…g by mistral-common when MistralTokenizer is used (vllm-project#9951)

Signed-off-by: Guillaume Calmettes <[email protected]>
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
…g by mistral-common when MistralTokenizer is used (vllm-project#9951)

Signed-off-by: Guillaume Calmettes <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 20, 2024
…g by mistral-common when MistralTokenizer is used (vllm-project#9951)

Signed-off-by: Guillaume Calmettes <[email protected]>
Signed-off-by: rickyx <[email protected]>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
…g by mistral-common when MistralTokenizer is used (vllm-project#9951)

Signed-off-by: Guillaume Calmettes <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
…g by mistral-common when MistralTokenizer is used (vllm-project#9951)

Signed-off-by: Guillaume Calmettes <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: "--tokenizer-mode", "mistral" not compatible with openai API tool use tests
3 participants